{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "2c421274-e807-4d93-8b0e-d23afdb49a2d",
    "showInput": false
   },
   "source": [
    "## Composite Bayesian Optimization with Multi-Task Gaussian Processes\n",
    "\n",
    "In this tutorial, we describe how to perform multi-task Bayesian optimization over composite functions. In these problems, there are several related outputs and an overall, easy-to-evaluate objective function of those outputs that we wish to maximize.\n",
    "\n",
    "**Multi-task Bayesian Optimization** was first proposed by [Swersky et al, NeurIPS, '13](https://papers.neurips.cc/paper/2013/hash/f33ba15effa5c10e873bf3842afb46a6-Abstract.html) in the context of fast hyperparameter tuning for neural network models; however, we demonstrate a more advanced use case of **[composite Bayesian optimization](https://proceedings.mlr.press/v97/astudillo19a.html)** where the overall function that we wish to optimize is a cheap-to-evaluate (and known) function of the outputs. In general, we expect that using more information about the function should yield improved performance when attempting to optimize it, particularly if the metric function itself varies quickly.\n",
    "\n",
    "See [the composite BO tutorial with HOGP](https://github.com/meta-pytorch/botorch/blob/main/tutorials/composite_bo_with_hogp/composite_bo_with_hogp.ipynb) for a more technical introduction. In general, we suggest using MTGPs for unstructured task outputs and the HOGP for matrix / tensor structured outputs.\n",
    "\n",
    "We will use a multi-task Gaussian process ([MTGP](https://papers.nips.cc/paper/2007/hash/66368270ffd51418ec58bd793f2d9b1b-Abstract.html)) with an intrinsic coregionalization model (ICM) kernel to model all of the outputs in this problem. MTGPs are available in BoTorch via the `botorch.models.KroneckerMultiTaskGP` model class (for the \"block design\" case, where every output is observed at every input). Given $T$ tasks (outputs) and $n$ data points, the model assumes that the responses $Y \\in \\mathbb{R}^{n \\times T}$ are distributed as $\\text{vec}(Y) \\sim \\mathcal{N}(f, D)$ with $f \\sim \\mathcal{GP}(\\mu_{\\theta}, K_{XX} \\otimes K_{T}),$ where $D$ is a (diagonal) noise term."
   ]
  },
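  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief sketch of what the Kronecker-structured prior above implies (the kernel symbol $k$, the index notation, and the simplifying noise assumption below are shorthand introduced here, not taken from the references): under the ICM kernel, the prior covariance between task $t$ at input $x_i$ and task $t'$ at input $x_j$ factorizes as\n",
    "$$\\text{cov}\\big(f_t(x_i), f_{t'}(x_j)\\big) = [K_{XX}]_{ij} \\, [K_T]_{tt'} = k(x_i, x_j) \\, [K_T]_{tt'}.$$\n",
    "Assuming, purely for illustration, homoskedastic noise $D = \\sigma^2 I$ and eigendecompositions $K_{XX} = Q_X \\Lambda_X Q_X^\\top$ and $K_T = Q_T \\Lambda_T Q_T^\\top$, the $nT \\times nT$ system can be solved without ever forming the full matrix:\n",
    "$$(K_{XX} \\otimes K_T + \\sigma^2 I)^{-1} = (Q_X \\otimes Q_T) (\\Lambda_X \\otimes \\Lambda_T + \\sigma^2 I)^{-1} (Q_X \\otimes Q_T)^\\top.$$\n",
    "The middle factor is diagonal, and the two eigendecompositions cost $\\mathcal{O}(n^3 + T^3)$, compared with $\\mathcal{O}(n^3 T^3)$ for a dense factorization of the full matrix. The noise model actually used by the model class may be more general than this sketch assumes."
   ]
  },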
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies if we are running in colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "    %pip install botorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932909107,
    "executionStopTime": 1678932912073,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "8871a990-f29b-45c2-b378-ac2befef0a1f",
    "requestMsgId": "e32e501e-2fe4-4b41-b1fd-4a3cf218833e"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from botorch.acquisition.logei import qLogExpectedImprovement\n",
    "from botorch.acquisition.objective import GenericMCObjective\n",
    "from botorch.models import KroneckerMultiTaskGP\n",
    "from botorch.optim import optimize_acqf\n",
    "from botorch.sampling.normal import IIDNormalSampler\n",
    "\n",
    "from botorch.test_functions import Hartmann\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "de394597-4088-4f32-94c7-7c611876eebc",
    "showInput": false
   },
   "source": [
    "### Set device, dtype and random seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932914757,
    "executionStopTime": 1678932914766,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "015a9aab-e3a5-4d09-bd5d-209f9b41cbdb",
    "requestMsgId": "00e96c89-1a35-4c00-852c-1f5dda614d9a"
   },
   "outputs": [],
   "source": [
    "torch.random.manual_seed(10)\n",
    "\n",
    "tkwargs = {\n",
    "    \"dtype\": torch.double,\n",
    "    \"device\": torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "f151402e-238e-4c5c-be43-55a73f07664e",
    "showInput": false
   },
   "source": [
    "### Problem Definition\n",
    "\n",
    "The function that we wish to optimize is based on a contextual version of the Hartmann-6 test function, where, following [Feng et al, NeurIPS, '20](https://proceedings.neurips.cc/paper/2020/hash/faff959d885ec0ecf70741a846c34d1d-Abstract.html), we convert the sixth input dimension into a task indicator. Here we assume that we evaluate all contexts at once."
   ]
  },
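  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out as a sketch (the symbols $h$ and $c_t$ are shorthand introduced here): let $h$ denote the six-dimensional Hartmann function, up to the optional negation and observation noise. The class in the next cell places the $T$ contexts on an evenly spaced grid over $[0, 1]$ and evaluates all of them at each five-dimensional input $x \\in [0, 1]^5$:\n",
    "$$f(x) = \\big(h(x, c_1), \\ldots, h(x, c_T)\\big) \\in \\mathbb{R}^T, \\qquad c_t = \\frac{t - 1}{T - 1}, \\quad t = 1, \\ldots, T.$$\n",
    "Each input point therefore yields one row of $T$ task outputs, so $n$ inputs give an $n \\times T$ response matrix, which matches the block-design layout described above."
   ]
  },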
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932917176,
    "executionStopTime": 1678932917215,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "48fabc12-b1ee-4b88-aa97-9ce4bfe83fd4",
    "requestMsgId": "0429d2a2-cc39-4838-9b87-f3eebd49e140"
   },
   "outputs": [],
   "source": [
    "from torch import Tensor\n",
    "\n",
    "\n",
    "class ContextualHartmann6(Hartmann):\n",
    "    def __init__(self, num_tasks: int = 20, noise_std=None, negate=False):\n",
    "        super().__init__(dim=6, noise_std=noise_std, negate=negate)\n",
    "        self.task_range = torch.linspace(0, 1, num_tasks).unsqueeze(-1)\n",
    "        self.dim = 5\n",
    "        self._bounds = [(0.0, 1.0) for _ in range(self.dim)]\n",
    "        self.bounds = torch.tensor(self._bounds).t()\n",
    "\n",
    "    def _evaluate_true(self, X: Tensor) -> Tensor:\n",
    "        batch_X = X.unsqueeze(-2)\n",
    "        batch_dims = X.ndim - 1\n",
    "\n",
    "        expanded_task_range = self.task_range\n",
    "        for _ in range(batch_dims):\n",
    "            expanded_task_range = expanded_task_range.unsqueeze(0)\n",
    "        task_range = expanded_task_range.repeat(*X.shape[:-1], 1, 1).to(X)\n",
    "        concatenated_X = torch.cat(\n",
    "            (\n",
    "                batch_X.repeat(*[1] * batch_dims, self.task_range.shape[0], 1),\n",
    "                task_range,\n",
    "            ),\n",
    "            dim=-1,\n",
    "        )\n",
    "        return super()._evaluate_true(concatenated_X)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "a2fbe3e8-98cf-49e0-9891-8d6009790955",
    "showInput": false
   },
   "source": [
    "We use `GenericMCObjective` to define the differentiable function that we are optimizing. Here, it is defined as\n",
    "$$g(f) = -\\sum_{i=1}^T \\cos(f_i^2 + f_i w_i),$$\n",
    "where $w$ is a weight vector (drawn randomly once at the start of the optimization). Because $g$ is a non-linear function of the outputs $f,$ we cannot compute acquisition functions in closed form from the posterior mean and variance; instead, we draw posterior samples and evaluate acquisition functions with Monte Carlo sampling.\n",
    "\n",
    "For more than $10$ or so tasks, it is computationally challenging to sample the posterior over all tasks jointly using conventional approaches. However, [Maddox et al, '21](https://arxiv.org/abs/2106.12997) devised an efficient method that exploits the structure in the posterior distribution of the MTGP, enabling efficient MC-based optimization of objectives using MTGPs. In this tutorial, we choose 6 contexts/tasks for demonstration."
   ]
  },
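  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To sketch how such a Monte Carlo acquisition value can be formed (the symbols $N$, $q$, $\\xi^{(j)}$ and $g^*$ are notation assumed here, and this is the plain expected-improvement estimator rather than the exact formula of any particular implementation): given a candidate batch $X = \\{x_1, \\ldots, x_q\\}$, draw $N$ joint posterior samples $\\xi^{(j)}(x_i) \\in \\mathbb{R}^T$, push each one through $g$, and average the improvement over the best composite value observed so far, $g^*$:\n",
    "$$\\hat{\\alpha}_{\\text{qEI}}(X) = \\frac{1}{N} \\sum_{j=1}^{N} \\max_{i = 1, \\ldots, q} \\Big[ g\\big(\\xi^{(j)}(x_i)\\big) - g^* \\Big]_+, \\qquad [u]_+ = \\max(u, 0).$$\n",
    "The optimization loop in this tutorial uses `qLogExpectedImprovement`, which, as its name suggests, works with a logarithmic variant of this quantity."
   ]
  },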
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932920811,
    "executionStopTime": 1678932929399,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "5938fb26-f773-4f68-a389-c0630e0687eb",
    "requestMsgId": "80cec36b-8b73-4c09-9511-f79f59df0af8"
   },
   "outputs": [],
   "source": [
    "num_tasks = 6\n",
    "problem = ContextualHartmann6(num_tasks=num_tasks, noise_std=0.001, negate=True).to(**tkwargs)\n",
    "\n",
    "# we choose num_tasks random weights\n",
    "weights = torch.randn(num_tasks, **tkwargs)\n",
    "\n",
    "\n",
    "def callable_func(samples, X=None):\n",
    "    res = -torch.cos((samples**2) + samples * weights)\n",
    "    return res.sum(dim=-1)\n",
    "\n",
    "\n",
    "objective = GenericMCObjective(callable_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932929407,
    "executionStopTime": 1678932929411,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "5a85661a-3668-4128-b5f4-8150dbcdce7d",
    "requestMsgId": "d1742d93-d229-4400-b3d9-55e5fee5eed6"
   },
   "outputs": [],
   "source": [
    "bounds = problem.bounds"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "eff2dc86-0797-40a7-962d-04e8c539c21a",
    "showInput": false
   },
   "source": [
    "## BO Loop"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "53414433-e097-4556-90e9-e99e35cdb390",
    "showInput": false
   },
   "source": [
    "Set the experiment parameters: we use 10 initial data points and optimize for 10 steps with a batch size of 3 candidate points at each step. In smoke-test mode, these settings are reduced so that the notebook runs quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932929437,
    "executionStopTime": 1678932929440,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "51f24269-b5de-4dd8-b2b4-18d269f3f3c2",
    "requestMsgId": "63c971a6-7e24-4b8a-8ec8-15d0ef21f6ff"
   },
   "outputs": [],
   "source": [
    "if SMOKE_TEST:\n",
    "    n_init = 5\n",
    "    n_steps = 1\n",
    "    batch_size = 2\n",
    "    num_samples = 4\n",
    "    # For L-BFGS inner optimization loop\n",
    "    MAXITER = 10\n",
    "else:\n",
    "    n_init = 10\n",
    "    n_steps = 10\n",
    "    batch_size = 3\n",
    "    num_samples = 64\n",
    "    MAXITER = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678932930875,
    "executionStopTime": 1678934072848,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "de8d4079-6c17-4041-9ec4-9ebd09d46cd7",
    "requestMsgId": "34279809-d637-4ea5-99c2-aaa3957c37ba"
   },
   "outputs": [],
   "source": [
    "from botorch.fit import fit_gpytorch_mll"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, run the optimization loop and compare against random sampling.\n",
    "\n",
    "Warning: this optimization loop can take a while, especially on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "originalKey": "b567ab22-f3b4-41e3-aaae-75e6fe820cfc",
    "showInput": false
   },
   "outputs": [],
   "source": [
    "# Both methods start from the same randomly drawn initial design\n",
    "torch.manual_seed(0)\n",
    "\n",
    "init_x = (bounds[1] - bounds[0]) * torch.rand(\n",
    "    n_init, bounds.shape[1], **tkwargs\n",
    ") + bounds[0]\n",
    "\n",
    "init_y = problem(init_x)\n",
    "\n",
    "mtgp_train_x, mtgp_train_y = init_x, init_y\n",
    "rand_x, rand_y = init_x, init_y\n",
    "\n",
    "best_value_mtgp = objective(init_y).max()\n",
    "best_random = best_value_mtgp\n",
    "\n",
    "for iteration in range(n_steps):\n",
    "    # we empty the cache to clear memory out\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # MTGP: refit the model on all data collected so far\n",
    "    mtgp_t0 = time.monotonic()\n",
    "    mtgp = KroneckerMultiTaskGP(mtgp_train_x, mtgp_train_y)\n",
    "    mtgp_mll = ExactMarginalLogLikelihood(mtgp.likelihood, mtgp)\n",
    "    fit_gpytorch_mll(mll=mtgp_mll, optimizer_kwargs={\"options\": {\"maxiter\": 50}})\n",
    "\n",
    "    # Monte Carlo acquisition function evaluated on the composite objective\n",
    "    sampler = IIDNormalSampler(sample_shape=torch.Size([num_samples]))\n",
    "    mtgp_acqf = qLogExpectedImprovement(\n",
    "        model=mtgp,\n",
    "        best_f=best_value_mtgp,\n",
    "        sampler=sampler,\n",
    "        objective=objective,\n",
    "    )\n",
    "    new_mtgp_x, _ = optimize_acqf(\n",
    "        acq_function=mtgp_acqf,\n",
    "        bounds=bounds,\n",
    "        q=batch_size,\n",
    "        num_restarts=10,\n",
    "        raw_samples=512,  # used for initialization heuristic\n",
    "        options={\"batch_limit\": 5, \"maxiter\": MAXITER, \"init_batch_limit\": 5},\n",
    "    )\n",
    "    mtgp_train_x = torch.cat((mtgp_train_x, new_mtgp_x), dim=0)\n",
    "    mtgp_train_y = torch.cat((mtgp_train_y, problem(new_mtgp_x)), dim=0)\n",
    "    best_value_mtgp = objective(mtgp_train_y).max()\n",
    "    mtgp_t1 = time.monotonic()\n",
    "\n",
    "    # Random-search baseline: draw the same number of points uniformly\n",
    "    new_rand_x = (bounds[1] - bounds[0]) * torch.rand(\n",
    "        batch_size, bounds.shape[1], **tkwargs\n",
    "    ) + bounds[0]\n",
    "    rand_x = torch.cat((rand_x, new_rand_x))\n",
    "    rand_y = torch.cat((rand_y, problem(new_rand_x)))\n",
    "    best_random = objective(rand_y).max()\n",
    "\n",
    "    print(\n",
    "        f\"\\nBatch {iteration:>2}: best_value (random, mtgp) = \"\n",
    "        f\"({best_random:>4.2f}, {best_value_mtgp:>4.2f}), \"\n",
    "        f\"mtgp time = {mtgp_t1-mtgp_t0:>4.2f}\",\n",
    "        end=\"\",\n",
    "    )\n",
    "\n",
    "objectives = {\n",
    "    \"MTGP\": objective(mtgp_train_y).detach().cpu(),\n",
    "    \"Random\": objective(rand_y).detach().cpu(),\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "b567ab22-f3b4-41e3-aaae-75e6fe820cfc",
    "showInput": false
   },
   "source": [
    "### Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "customOutput": null,
    "executionStartTime": 1678934072894,
    "executionStopTime": 1678934073463,
    "jupyter": {
     "outputs_hidden": false
    },
    "originalKey": "8ec120a6-bc85-42cc-9c59-9963964a0da3",
    "requestMsgId": "b1f7777e-e686-4325-a0e0-d0c8578ec106"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "originalKey": "3585efc5-bcb5-49f0-9fe1-e50d2deccfd2",
    "showInput": false
   },
   "source": [
    "Finally, we plot the results. We expect the MTGP to outperform the random baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {\n",
    "    k: t[n_init:].cummax(0).values for k, t in objectives.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, vals in results.items():\n",
    "    plt.plot(vals, label=name)\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "custom": {
   "cells": [],
   "metadata": {
    "custom": {
     "cells": [],
     "metadata": {
      "fileHeader": "",
      "isAdHoc": false,
      "kernelspec": {
       "display_name": "ae",
       "language": "python",
       "name": "bento_kernel_ae"
      },
      "language_info": {
       "codemirror_mode": {
        "name": "ipython",
        "version": 3
       },
       "file_extension": ".py",
       "mimetype": "text/x-python",
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
       "version": "3.10.8"
      }
     },
     "nbformat": 4,
     "nbformat_minor": 2
    },
    "fileHeader": "",
    "indentAmount": 2,
    "isAdHoc": false,
    "kernelspec": {
     "display_name": "python3",
     "language": "python",
     "name": "python3"
    },
    "language_info": {
     "name": "plaintext"
    }
   },
   "nbformat": 4,
   "nbformat_minor": 2
  },
  "indentAmount": 2,
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
